from botasaurus import *
from botasaurus.browser_decorator import browser
from botasaurus_driver.user_agent import UserAgent
import csv
import os
import time
import random


# --- Helper functions ---

def load_restaurant_urls(urls_file):
    """Load the list of restaurant URLs from a CSV file.

    The CSV must have 'city' and 'url' columns; any '#fragment' suffix is
    stripped from each URL. Returns a list of {'city', 'url'} dicts, or an
    empty list when the file is missing or unreadable.
    """
    try:
        with open(urls_file, mode='r', encoding='utf-8', newline='') as fh:
            entries = [
                {'city': row['city'], 'url': row['url'].split('#')[0]}
                for row in csv.DictReader(fh)
                if 'city' in row and 'url' in row
            ]
    except FileNotFoundError:
        print(f"错误: 无法找到文件 '{urls_file}'。")
        return []
    except Exception as e:
        print(f"读取 '{urls_file}' 时出错: {e}")
        return []
    print(f"成功加载 {len(entries)} 个餐厅URL。")
    return entries


def load_checkpoint(checkpoint_file):
    """Return the index of the next URL to process, or 0 when no usable
    checkpoint exists.

    The checkpoint file holds a single integer. A missing file, an
    unreadable file, or non-integer content all fall back to 0.
    """
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, mode='r', encoding='utf-8') as file:
                return int(file.read().strip())
        # Bug fix: the original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt; catch only the failures we can recover from
        # (I/O errors and non-integer content).
        except (OSError, ValueError):
            print(f"警告: 无法读取断点文件 '{checkpoint_file}'，从头开始。")
    return 0


def save_checkpoint(checkpoint_file, index):
    """Persist *index* (the next URL index to process) to the checkpoint file.

    Failures are logged rather than raised, so a checkpointing problem never
    aborts the scraping run itself.
    """
    try:
        with open(checkpoint_file, mode='w', encoding='utf-8') as out:
            out.write(str(index))
    except Exception as e:
        print(f"错误: 无法保存断点到 '{checkpoint_file}': {e}")


def save_restaurant_data(filepath, data):
    """Append one restaurant record to the CSV file at *filepath*.

    The header row is written only when the file does not yet exist or is
    empty, so repeated calls keep appending data rows. Missing keys in
    *data* are written as empty cells; write errors are logged, not raised.
    """
    fieldnames = [
        '城市', '品牌名称', '评分', '平均价格', '一级品类', '二级品类', '评论数目', '电话', '邮箱', '品牌链接', '地址']
    try:
        # Header is needed unless the file already exists with content.
        needs_header = not (os.path.exists(filepath) and os.path.getsize(filepath) > 0)

        # Append mode creates the file on first use.
        with open(filepath, mode='a', encoding='utf-8', newline='') as out:
            writer = csv.DictWriter(out, fieldnames=fieldnames)
            if needs_header:
                writer.writeheader()
            writer.writerow(data)
            out.flush()  # push the row to disk immediately
    except Exception as e:
        print(f"错误: 无法写入数据到 '{filepath}': {e}")

@browser(
    user_agent=UserAgent.REAL,
    chrome_executable_path=r"C:\Program Files\Google\Chrome\Application\chrome.exe",
    cache=True,
    block_images=True,
    lang="zh-CN",
    reuse_driver=True,
    close_on_crash=True,
    create_error_logs=True,
    headless=False,  # keep the browser visible so the simulated behavior can be observed
)
def scrape_restaurant_details(driver, data, country=None):
    """Scrape restaurant detail pages for one country and append them to CSV.

    Reads "<country>_restaurant_urls.csv", visits each URL with the
    Botasaurus-managed *driver*, extracts brand/rating/price/category/
    review-count/phone/address/email fields, and appends each record to
    "<country>_restaurant_details.csv". Progress is checkpointed to
    "<country>_checkpoint.txt" after every URL so an interrupted run resumes
    where it left off.

    Args:
        driver: browser driver injected by the @browser decorator.
        data: optional dict payload; may carry the country under "country".
        country: country name; takes precedence over data["country"].

    Raises:
        ValueError: when no country is supplied via either argument.
    """
    # Resolve the country from the explicit argument or the data payload.
    if country is None and data is not None:
        country = data.get("country")
    if country is None:
        raise ValueError("必须提供 country 参数或在 data 中包含 country 键")

    restaurant_urls_file = f"{country}_restaurant_urls.csv"
    output_file = f"{country}_restaurant_details.csv"
    checkpoint_file = f"{country}_checkpoint.txt"
    requests_per_minute = 2  # crude politeness throttle
    min_delay = 60 / requests_per_minute

    print(f"🌍 开始处理国家: {country}")
    print(f"📁 输入文件: {restaurant_urls_file}")
    print(f"📁 输出文件: {output_file}")
    print(f"📁 断点文件: {checkpoint_file}")

    # Load the restaurant URL list.
    restaurant_urls = load_restaurant_urls(restaurant_urls_file)
    if not restaurant_urls:
        print("没有可处理的餐厅URL，程序退出。")
        return

    # Resume from the checkpoint, if any.
    start_index = load_checkpoint(checkpoint_file)
    if start_index >= len(restaurant_urls):
        print("所有URL已处理完成。")
        return

    print(f"从索引 {start_index} 开始处理，共 {len(restaurant_urls)} 个URL。")

    def _extract(label, getter, default=''):
        """Run *getter*, returning *default* (and logging) on any failure.

        Bug fix vs. the original bare `except:` clauses: only Exception is
        caught, so Ctrl-C (KeyboardInterrupt) still stops the scraper.
        """
        try:
            return getter()
        except Exception:
            print(f"无法提取{label}。")
            return default

    def _parse_price():
        """Average a '¥low-high' price range; fall back to the raw text."""
        text = driver.select("span.HUMGB.cPbcf span.bTeln:last-of-type a").text
        if '-' in text and any(ch.isdigit() for ch in text):
            text = text.replace('¥', '')
            # Parse defensively: keep only the digits of each side so stray
            # non-numeric characters cannot break int().
            prices = []
            for part in text.split('-'):
                digits = ''.join(filter(str.isdigit, part))
                if digits:
                    prices.append(int(digits))
            if prices:
                return sum(prices) / len(prices)
        # Unparseable range (or no range at all): keep the text as-is.
        return text

    for index in range(start_index, len(restaurant_urls)):
        task = restaurant_urls[index]
        city = task['city']
        url = task['url']
        print(f"\n--- 开始处理: {city} - {url} ---")

        try:
            # Load the page, then wait a randomized delay that keeps the
            # request rate at or below `requests_per_minute`.
            driver.get(url)
            time.sleep(random.uniform(min_delay, min_delay + 3))

            # Build the output record (named `record` so it does not clobber
            # the `data` parameter, as the original code did).
            record = {'城市': city.replace('餐厅', ''), '品牌链接': url}

            record['品牌名称'] = _extract(
                '品牌名称', lambda: driver.select("h1.biGQs._P.hzzSG").text)
            record['评分'] = _extract(
                '评分', lambda: driver.select('div[data-automation="bubbleRatingValue"]').text)
            record['平均价格'] = _extract('价格范围', _parse_price)
            record['一级品类'] = _extract(
                '一级品类', lambda: driver.select("span.HUMGB.cPbcf a:last-of-type").text)
            # Second-level category (the original comment mislabeled this
            # block as the first-level category).
            record['二级品类'] = _extract(
                '二级品类', lambda: driver.select_all("span.HUMGB.cPbcf a")[1].text)
            record['评论数目'] = _extract(
                '评论数目',
                lambda: ''.join(filter(str.isdigit, driver.select(
                    'div.CsAqy a[href="#REVIEWS"] div[data-automation="bubbleReviewCount"]').text)))
            record['电话'] = _extract(
                '电话', lambda: driver.select("a[href^='tel:']").text.replace(' ', ''))
            record['地址'] = _extract(
                '地址', lambda: driver.select("button.Tbrbj span").text)
            record['邮箱'] = _extract(
                '邮箱',
                lambda: driver.select("a[href^='mailto:']").get_attribute('href').replace('mailto:', ''))

            save_restaurant_data(output_file, record)
            print(f"成功保存数据: {record['品牌名称']} ({url})")
            # Persist index + 1 so a restart resumes with the NEXT URL.
            save_checkpoint(checkpoint_file, index + 1)
        except Exception as e:
            print(f"处理 {url} 时出错: {e}")
            # Keep going: one broken page must not abort the whole run.

    print("所有餐厅URL处理完成。")


# --- Run the scraper ---

if __name__ == "__main__":
    start_time = time.time()

    # Countries whose restaurant URL files should be scraped.
    countries_to_scrape = ["乌兹别克斯坦"]

    # Process each country in turn.
    for country in countries_to_scrape:
        print(f"\n=========================================")
        print(f"开始处理国家: {country}")
        print(f"=========================================")

        # Skip countries whose input CSV has not been prepared yet.
        restaurant_urls_file = f"{country}_restaurant_urls.csv"
        if not os.path.exists(restaurant_urls_file):
            print(f"错误: 找不到餐厅URL文件 '{restaurant_urls_file}'。请先准备好该文件。")
            continue

        # Botasaurus v4 calling convention: pass the country via the
        # `data` payload rather than as a positional argument.
        scrape_restaurant_details(data={"country": country})

    print("\n所有国家的任务都已处理完毕。")
    print(f"总耗时: {time.time() - start_time:.2f} 秒。")